import numpy as np
from utils.TE import TrafficEngineering
import math
import gurobipy as grb
from root_const import ROOT_PATH

class Routing(TrafficEngineering):
    """
    Desc: Direct traffic splitting.
          Every pod pair (i, j) splits its demand between the direct link
          i->j and the two-hop paths i->k->j. The split fractions are chosen
          by a Gurobi LP that minimizes the maximum link utilization (MLU).
          Split fractions are keyed 'w_{i}_{k}_{j}' (pods are 1-based);
          k == 0 denotes the direct path, any other k is the intermediate pod.
    """
    # Class-level default kept for backward compatibility; __init__ gives each
    # instance its own dict and traffic_engineering() rebinds it per solve.
    _w_routing = {}

    def __init__(self, topology_pods):
        """
        Desc: Store the inter-pod link counts computed by the preceding
              topology-engineering step, for use by the routing LP.
        Inputs:
            - topology_pods(2-dimension list or np.ndarray): Number of s_i
              egress links connected to ingress links of s_j.
        """
        # NOTE(review): the original docstring says the base-class
        # configuration should also be loaded before routing, but
        # TrafficEngineering.__init__ is never called here. Its signature is
        # not visible from this file -- confirm and add super().__init__(...)
        # if required.
        self._topology_pods = topology_pods
        self._w_routing = {}

    def traffic_engineering(self, actual_traffic, a_link_bandwidth):
        """Solve the MLU-minimizing split LP with Gurobi.

        Inputs:
            - actual_traffic: a single traffic matrix (2-D) or a stack of
              traffic matrices (3-D). Entry [a - 1][b - 1] is the demand from
              pod a to pod b. With a stack, one routing must keep every matrix
              within ``mlu`` times the link capacities.
            - a_link_bandwidth: bandwidth of one inter-pod link; the capacity
              of link i->j is topology_pods[i-1][j-1] * a_link_bandwidth.
        Returns:
            (w_routing, mlu): w_routing maps every 'w_{i}_{k}_{j}' name to its
            split fraction in [0, 1]; mlu is the optimal objective value.
            Returns None (after printing 'No solution') when Gurobi does not
            reach an optimal solution, e.g. when a pair with no direct link
            has no usable two-hop path either.
        Raises:
            ValueError: if actual_traffic is neither 2-D nor 3-D.

        Requires self._pods_num to be set beforehand (routing() sets it).
        """
        traffic = np.asarray(actual_traffic)
        if traffic.ndim == 2:
            traffic_matrices = [traffic]
        elif traffic.ndim == 3:
            traffic_matrices = traffic
        else:
            # Without this check no capacity constraint would be added and the
            # solver would report mlu == 0 with an arbitrary split.
            raise ValueError(
                'actual_traffic must be a 2-D or 3-D array, '
                f'got shape {traffic.shape}'
            )

        pods_num = self._pods_num
        pods = range(1, pods_num + 1)
        # np.asarray first: for a nested list, `list * n` would repeat the
        # rows instead of scaling the link counts (and `list * float` raises).
        capacity = np.asarray(self._topology_pods) * a_link_bandwidth

        m = grb.Model('traffic_engineering_grb')
        m.Params.OutputFlag = 0
        mlu = m.addVar(lb=0, vtype=grb.GRB.CONTINUOUS, name='mlu')

        name_w_ikj = [
            f'w_{i}_{k}_{j}'
            for i in pods
            for k in range(pods_num + 1)
            for j in pods
            if i != k and i != j and j != k
        ]
        w = m.addVars(name_w_ikj, lb=0, ub=1, vtype=grb.GRB.CONTINUOUS, name='w')

        def link_load(demand, i, j):
            # Load on the directed link i->j: the direct share of pair (i, j)
            # plus every two-hop flow whose second hop (k->i->j) or first hop
            # (i->j->k) is i->j, each weighted by the demand of its pair.
            return grb.quicksum(
                (float(demand[i - 1, j - 1]) * w[f'w_{i}_0_{j}']
                 if k == 0
                 else float(demand[k - 1, j - 1]) * w[f'w_{k}_{i}_{j}']
                 + float(demand[i - 1, k - 1]) * w[f'w_{i}_{j}_{k}'])
                for k in range(pods_num + 1)
                if k != i and k != j
            )

        # Pod pairs without any physical link may not carry traffic. With unit
        # demand the load is the plain sum of all fractions using i->j, and
        # since every w >= 0, forcing the sum to 0 zeroes each of them.
        unit_demand = np.ones((pods_num, pods_num))
        m.addConstrs(
            link_load(unit_demand, i, j) == 0
            for i in pods
            for j in pods
            if i != j and capacity[i - 1, j - 1] == 0
        )

        # The split fractions of every pair (i, j) sum to 1.
        m.addConstrs(
            (grb.quicksum(
                w[f'w_{i}_{k}_{j}']
                for k in range(pods_num + 1)
                if k != i and k != j
            ) == 1
             for i in pods
             for j in pods
             if i != j),
            name='sumOneConstrs',
        )

        # Load on every link stays within mlu * capacity, for every matrix.
        for demand in traffic_matrices:
            m.addConstrs(
                link_load(demand, i, j) <= mlu * float(capacity[i - 1, j - 1])
                for i in pods
                for j in pods
                if i != j
            )

        m.setObjective(mlu, grb.GRB.MINIMIZE)
        # m.write('debug.lp')  # uncomment to dump the LP for inspection
        m.optimize()
        if m.status != grb.GRB.Status.OPTIMAL:
            print('No solution')
            return None

        solution = m.getAttr('X', w)
        w_routing = {name: solution[name] for name in name_w_ikj}
        self._w_routing = w_routing
        return w_routing, m.objVal

    def routing(self, pods_num, a_link_bandwidth, traffic):
        """Compute the split fractions for `pods_num` pods.

        Inputs:
            - pods_num: number of pods (pods are numbered 1..pods_num).
            - a_link_bandwidth: bandwidth of one inter-pod link.
            - traffic: one traffic matrix or a stack of them
              (see traffic_engineering).
        Returns:
            The dict mapping 'w_{i}_{k}_{j}' to its split fraction. Returns
            None when no optimal solution was found; traffic_engineering has
            already printed 'No solution' in that case.
        """
        self._pods_num = pods_num
        result = self.traffic_engineering(traffic, a_link_bandwidth)
        if result is None:
            # Previously this unpacked None and crashed with a TypeError.
            return None
        w_routing, _mlu = result
        return w_routing

if __name__ == "__main__":
    # Running this file directly performs no action.
    ...
